#----------------------------------------------------------------------
#  GFDM method test - 2d perforated plate
#  Transient heat with Neumann BCs only (no Dirichlet, no Robin)
#  Author: Andrea Pavan
#  Date: 03/02/2023
#  License: GPLv3-or-later
#----------------------------------------------------------------------
using ElasticArrays;
using LinearAlgebra;
using SparseArrays;
using PyPlot;
include("utils.jl");


#problem definition (values presumably in SI units)

#geometry
Dfuel = 1.27e-2;            #diameter of the fuel rod (circle centered at the origin)
Dcoolant = 1.587e-2;        #diameter of the coolant rod
L = 1.879e-2;               #distance between the fuel and coolant rod centers

#material properties
k = 10;                     #thermal conductivity
ρ = 1;                      #density
cp = 1e6;                   #specific heat capacity

#imposed heat fluxes
q1 = 10000;                 #flux on the fuel rod boundary ("bottom")
q2 = -5000;                 #flux on the coolant rod boundary ("top")

#time integration
T0 = 0;                     #uniform initial temperature
t0 = 0;                     #initial time
Δt = 0.1;                   #time step
tf = 3.0;                   #final time

#discretization
meshSize = 0.02e-2;             #reference mesh spacing
surfaceMeshSize = meshSize;     #reference spacing on the boundaries
minNeighbors = 10;              #minimum number of neighbors required for every node
minSearchRadius = meshSize;     #initial neighbor search radius (also sets its growth step)


#read pointcloud from a SU2 file
#Every node is then classified by position (the first matching test wins):
#  AC - fuel rod arc, centered at the origin:              normal towards the fuel rod center
#  ED - coolant rod arc, centered at (L/2, sqrt(3)*L/2):   normal towards the coolant rod center
#  AB - bottom edge y = 0:                                 normal [0,-1]
#  BE - right edge x = L/2:                                normal [1,0]
#  CD - slanted edge y = sqrt(3)*x:                        normal [-sqrt(3)/2, 1/2]
#  otherwise an internal node:                             normal [0,0]
#The 1e-6 tolerance applies to coordinates for the edges and to squared distances for the arcs.
time1 = time();
boundaryNodes = Vector{Int}(undef,0);       #indices of the boundary nodes
internalNodes = Vector{Int}(undef,0);       #indices of the internal nodes
normals = ElasticArray{Float64}(undef,2,0);     #2xN matrix containing the components [nx;ny] of the normal of each node (zero for internal nodes)

pointcloud = parseSU2mesh("15d_2d_transient_heat_perforated_plate_neumann_only.su2");     #2xN matrix containing the coordinates [X;Y] of each node
#drop the nodes at the fuel rod center (origin) and at the coolant rod center (L/2, sqrt(3)*L/2)
#NOTE(review): presumably auxiliary points stored in the SU2 file - confirm against the mesh
cornerPoint = findall((pointcloud[2,:].<=1e-6).*(pointcloud[1,:].<=1e-6));
pointcloud = pointcloud[:, setdiff(1:end,cornerPoint)];
cornerPoint = findall((pointcloud[2,:].>=0.5*sqrt(3)*L-1e-6).*(pointcloud[1,:].>=0.5*L-1e-6));
pointcloud = pointcloud[:, setdiff(1:end,cornerPoint)];
N = size(pointcloud,2);
idxC = 0;       #index of the reference point (0 = not found)
for i=1:N
    if pointcloud[1,i]^2+pointcloud[2,i]^2<=(Dfuel/2)^2+1e-6
        #AC segment
        push!(boundaryNodes, i);
        append!(normals, [-pointcloud[1,i],-pointcloud[2,i]]./(Dfuel/2));
        #reference point C: the fuel arc node lying on the bottom edge (x ≈ Dfuel/2, y ≈ 0);
        #the x bound uses the same 1e-6 tolerance as every other test
        if pointcloud[1,i]<=Dfuel/2+1e-6 && pointcloud[2,i]<=0+1e-6
            global idxC = i;
        end
    elseif (pointcloud[1,i]-0.5*L)^2+(pointcloud[2,i]-0.5*sqrt(3)*L)^2<=(Dcoolant/2)^2+1e-6
        #ED segment
        push!(boundaryNodes, i);
        append!(normals, [-(pointcloud[1,i]-0.5*L),-(pointcloud[2,i]-0.5*sqrt(3)*L)]./(Dcoolant/2));
    elseif pointcloud[2,i]<=0+1e-6
        #AB segment
        push!(boundaryNodes, i);
        append!(normals, [0,-1]);
    elseif pointcloud[1,i]>=0.5*L-1e-6
        #BE segment
        push!(boundaryNodes, i);
        append!(normals, [1,0]);
    elseif pointcloud[2,i]>=sqrt(3)*pointcloud[1,i]-1e-6
        #CD segment
        push!(boundaryNodes, i);
        append!(normals, [-0.5*sqrt(3),0.5]);
    else
        push!(internalNodes, i);
        append!(normals, [0,0]);
    end
end
#fail early: the time loop reads uprev[idxC], which would otherwise raise an opaque BoundsError
idxC > 0 || error("Reference point C not found: no fuel arc node lies on the bottom edge");
println("Reference point C index: ",idxC);
println("Generated pointcloud in ", round(time()-time1,digits=2), " s");
println("Pointcloud properties:");
println("  Boundary nodes: ",length(boundaryNodes));
println("  Internal nodes: ",length(internalNodes));
println("  Memory: ",memoryUsage(pointcloud,boundaryNodes));


#normals plot
#=import Plots;
plt1Idx = boundaryNodes;
figure();
plt1 = Plots.quiver(pointcloud[1,plt1Idx],pointcloud[2,plt1Idx],
        quiver = (normals[1,plt1Idx]/1000,normals[2,plt1Idx]/1000),
        title = "Normals plot",
        xlim = (-0.002,0.012),
        ylim = (-0.002,0.012),
        aspect_ratio = :equal);
display(plt1);=#


#boundary conditions
#Each boundary node i carries a condition  g1*u + g2*∂u/∂n = g3,  n being its normal in `normals`:
#  fuel rod arc (AC, bottom):      k*∂u/∂n = q1
#  coolant rod arc (ED, top):      k*∂u/∂n = q2
#  straight edges (AB, BE, CD):    ∂u/∂n = 0
#Entries of internal nodes are left at zero.
N = size(pointcloud,2);     #number of nodes
g1 = zeros(Float64,N);
g2 = zeros(Float64,N);
g3 = zeros(Float64,N);
for i in boundaryNodes
    px = pointcloud[1,i];
    py = pointcloud[2,i];
    onFuelArc = px^2+py^2<=(Dfuel/2)^2+1e-6;
    onCoolantArc = (px-0.5*L)^2+(py-0.5*sqrt(3)*L)^2<=(Dcoolant/2)^2+1e-6;
    if onFuelArc
        bc = (0.0, k, q1);          #prescribed flux q1
    elseif onCoolantArc
        bc = (0.0, k, q2);          #prescribed flux q2
    else
        bc = (0.0, 1.0, 0.0);       #zero normal derivative
    end
    g1[i], g2[i], g3[i] = bc;
end


#neighbor search
#The neighbors of node i are all the other nodes inside the axis-aligned square
#|x_j-x_i| < searchradius, |y_j-y_i| < searchradius, listed in ascending index order.
#The radius starts at minSearchRadius and grows by minSearchRadius/2 until at least
#minNeighbors neighbors are found.
time2 = time();
#with N <= minNeighbors enough neighbors can never be found and the radius would grow forever
N > minNeighbors || error("Neighbor search needs more than minNeighbors = $minNeighbors nodes, but the pointcloud has only $N");
neighbors = Vector{Vector{Int}}(undef,N);       #vector containing N vectors of the indices of each node neighbors
Nneighbors = zeros(Int,N);      #number of neighbors of each node
for i=1:N
    searchradius = minSearchRadius;
    while Nneighbors[i]<minNeighbors
        #check every other node; indices are distinct by construction, and the
        #coordinate differences are compared without allocating temporary columns
        neighbors[i] = [j for j=1:N if j!=i &&
                        all(d -> abs(pointcloud[d,j]-pointcloud[d,i])<searchradius, axes(pointcloud,1))];
        Nneighbors[i] = length(neighbors[i]);
        searchradius += minSearchRadius/2;
    end
end
println("Found neighbors in ", round(time()-time2,digits=2), " s");
println("Connectivity properties:");
println("  Max neighbors: ",maximum(Nneighbors)," (at index ",findfirst(isequal(maximum(Nneighbors)),Nneighbors),")");
println("  Avg neighbors: ",round(sum(Nneighbors)/length(Nneighbors),digits=2));
println("  Min neighbors: ",minimum(Nneighbors)," (at index ",findfirst(isequal(minimum(Nneighbors)),Nneighbors),")");


#neighbors distances and weights
#For node i and its j-th neighbor:
#  P[i][:,j] - position of the neighbor relative to node i
#  r2[i][j]  - squared distance between the two nodes
#  w[i][j]   - least-squares weight exp(-6*r2[i][j]/max(r2[i]))
time3 = time();
P = Vector{Array{Float64}}(undef,N);
r2 = Vector{Vector{Float64}}(undef,N);
w = Vector{Vector{Float64}}(undef,N);
for i=1:N
    nn = Nneighbors[i];
    rel = Array{Float64}(undef,2,nn);
    d2 = Vector{Float64}(undef,nn);
    for (j, nj) in enumerate(neighbors[i])
        rel[:,j] = pointcloud[:,nj]-pointcloud[:,i];
        d2[j] = rel[:,j]'rel[:,j];
    end
    d2max = maximum(d2);
    P[i] = rel;
    r2[i] = d2;
    w[i] = [exp(-6*d/d2max) for d in d2];       #uniform weights (all 1.0) are a possible alternative
end
wpde = 2.0;       #least squares weight for the pde
wbc = 2.0;        #least squares weight for the boundary condition


#least square matrix inversion
#Around node i the solution is approximated by the quadratic expansion
#  u(x,y) ≈ a1 + a2*x + a3*y + a4*x^2 + a5*y^2 + a6*x*y      (x,y relative to node i)
#whose coefficients a are fitted, in the weighted least-squares sense, to:
#  - one row per neighbor j:   [1, xj, yj, xj^2, yj^2, xj*yj]⋅a = u_j        (weight w[i][j])
#  - one PDE row:              2k/(ρ*cp)*(a4+a5), i.e. k/(ρ*cp)*∇²u          (weight wpde)
#  - boundary nodes only:      g1*a1 + g2*(nx*a2 + ny*a3) = g3              (weight wbc)
#C[i] is the 6 x (number of rows) operator mapping the right-hand sides of these rows to a;
#its first row expresses u at node i in terms of neighbor values, PDE rhs and BC rhs.

"""
    quadraticRows(Pi, nExtraRows)

Return a `(size(Pi,2)+nExtraRows) x 6` matrix whose first `size(Pi,2)` rows hold the
monomials `[1, x, y, x^2, y^2, x*y]` of the relative positions stored in the columns of
`Pi`; the last `nExtraRows` rows are zero and are meant to be filled by the caller.
"""
function quadraticRows(Pi::AbstractMatrix, nExtraRows::Integer)
    nn = size(Pi,2);
    V = zeros(Float64, nn+nExtraRows, 6);
    for j=1:nn
        x = Pi[1,j];
        y = Pi[2,j];
        V[j,:] = [1, x, y, x^2, y^2, x*y];
    end
    return V;
end

"""
    weightedLsqOperator(V, weights)

Return `R \\ (Q' * W)`, where `W = Diagonal(weights)` and `W*V = Q*R` is a thin QR
factorization: the operator giving the weighted least-squares solution `a = op*rhs`
of `V*a ≈ rhs`. Solving with the triangular `R` avoids forming `inv(R)` explicitly.
"""
function weightedLsqOperator(V::AbstractMatrix, weights::AbstractVector)
    W = Diagonal(weights);
    F = qr(W*V);
    return UpperTriangular(F.R) \ (transpose(Matrix(F.Q))*W);
end

C = Vector{Matrix{Float64}}(undef,N);        #derivatives coefficients matrices
pdeRow = [0, 0, 0, 2*k/(ρ*cp), 2*k/(ρ*cp), 0];      #Laplacian row shared by every node
for i in internalNodes
    V = quadraticRows(P[i], 1);
    V[1+Nneighbors[i],:] = pdeRow;
    C[i] = weightedLsqOperator(V, vcat(w[i],wpde));
end
for i in boundaryNodes
    V = quadraticRows(P[i], 2);
    V[1+Nneighbors[i],:] = pdeRow;
    V[2+Nneighbors[i],:] = [g1[i], g2[i]*normals[1,i], g2[i]*normals[2,i], 0, 0, 0];
    C[i] = weightedLsqOperator(V, vcat(w[i],wpde,wbc));
end
println("Inverted least-squares matrices in ", round(time()-time3,digits=2), " s");


#matrix assembly
#Row i of M encodes  (1 - C[i][1,pde]/Δt)*u_i - Σ_j C[i][1,j]*u_j,
#with pde = Nneighbors[i]+1 and j running over the neighbors of node i.
time4 = time();
rows = Int[];
cols = Int[];
vals = Float64[];
for i=1:N
    nn = Nneighbors[i];
    Ci = C[i];
    append!(rows, fill(i, nn+1));
    append!(cols, [i; neighbors[i]]);
    append!(vals, [1-Ci[1,nn+1]/Δt; -Ci[1,1:nn]]);
end
M = sparse(rows,cols,vals,N,N);
println("Completed matrix assembly in ", round(time()-time4,digits=2), " s");


#time propagation
#Implicit Euler: each step solves M*u = b, where (see the matrix assembly)
#  b_i = -C[i][1,pde]*uprev_i/Δt                       for internal nodes
#  b_i = -C[i][1,pde]*uprev_i/Δt + C[i][1,bc]*g3_i     for boundary nodes
#with pde = Nneighbors[i]+1 and bc = Nneighbors[i]+2.
#M does not change between steps, so it is factorized once and reused.
#uC[n] is the temperature at the reference node at time tsteps[n]. No solve is performed
#after tf is recorded, so uprev holds the solution at tf (not tf+Δt) when the loop ends.
time5 = time();
uprev = T0*ones(N);     #solution at the previous step
uC = Float64[];
Mfact = lu(M);
tsteps = t0:Δt:tf;
for (n, t) in enumerate(tsteps)
    push!(uC, uprev[idxC]);
    println("t = ",t,"; uC = ",uC[end]);
    n == length(tsteps) && break;       #final time reached: no further step needed
    b = zeros(N);       #rhs vector
    for i in internalNodes
        b[i] = -C[i][1,1+Nneighbors[i]]*uprev[i]/Δt;
    end
    for i in boundaryNodes
        b[i] = -C[i][1,1+Nneighbors[i]]*uprev[i]/Δt + C[i][1,2+Nneighbors[i]]*g3[i];
    end
    global uprev = Mfact\b;
end
println("Simulation completed in ", round(time()-time5,digits=2), " s");

#solution plot: every node drawn at its position and colored by its final temperature (uprev)
figure();
plt = scatter(pointcloud[1,:], pointcloud[2,:]; c=uprev, cmap="jet");
title("Numerical solution");
axis("equal");
colorbar(plt);
display(gcf());

#validation plot: temperature history at the reference node vs. FEM reference values
tFEM = [0.1, 0.2, 0.35, 0.55, 0.75, 0.95, 1.15, 1.35, 1.55, 1.75, 1.95, 2.15, 2.35, 2.55, 2.75, 2.95, 3];
TFEM = [0.9308771, 1.373251, 1.833629, 2.321856, 2.745425, 3.126666, 3.476643, 3.802391, 4.109068, 4.400678, 4.680385, 4.950698, 5.213602, 5.470667, 5.723131, 5.97197, 6.033968];
figure();
plot(collect(t0:Δt:tf), uC, "r-"; linewidth=1.0, label="GFDM");
plot(tFEM, TFEM, "k."; label="FEM");
title("Temperature evolution - node C");
legend(loc="upper left");
xlabel("Time t");
ylabel("Temperature TC");
display(gcf());
